{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Working with pre-extracted embeddings\n",
    "\n",
    "When you're testing a Few-Shot Learning model, you are going to solve hundreds of randomly sampled few-shot tasks. In doing so, you are more than likely to process the same images several times. This means that these images will go through your backbone several times, which is a waste of time and energy. Fortunately, most Few-Shot Learning methods nowadays use a **frozen backbone**: the logic of these methods operates at the feature level. Therefore, you can extract the features of your images once and for all, and then use these features to solve your few-shot tasks.\n",
    "\n",
    "All the necessary tools to do so are available in EasyFSL. In this tutorial, we will show you how to use them.\n",
    "\n",
    "## Extracting the features\n",
    "\n",
    "EasyFSL provides a `predict_embeddings()` function, which takes a DataLoader and a torch Module as input and returns a DataFrame with all your embeddings. Let's use it to extract all the embeddings from the test set of the CUB dataset. As a backbone, we are going to use the Swin Transformer pre-trained on ImageNet and directly available from torchvision. Note that we can do that because there is no intersection between CUB and ImageNet, so we are not technically cheating. Still, the resulting performance cannot be compared with that of a model trained on CUB's train set, since the training data is not the same.\n",
    "\n",
    "First do some necessary configuration (this is not the interesting part)."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "try:\n",
    "    import google.colab\n",
    "\n",
    "    colab = True\n",
    "except ImportError:\n",
    "    colab = False\n",
    "\n",
    "if colab is True:\n",
    "    # Running in Google Colab\n",
    "    # Clone the repo\n",
    "    !git clone https://github.com/sicara/easy-few-shot-learning\n",
    "    %cd easy-few-shot-learning\n",
    "    !pip install .\n",
    "else:\n",
    "    # Run locally\n",
    "    # Ensure working directory is the project's root\n",
    "    # Make sure easyfsl is installed!\n",
    "    %cd .."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Then we prepare the data and the model."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Download CUB if necessary\n",
    "!make download-cub"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.models\n",
    "from easyfsl.datasets import CUB\n",
    "\n",
    "batch_size = 128\n",
    "num_workers = 1\n",
    "\n",
    "# Ensure that you're working from the root of the repository\n",
    "# and that CUB's images are in ./data/CUB/images/\n",
    "dataset = CUB(split=\"test\", training=False)\n",
    "dataloader = DataLoader(\n",
    "    dataset,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    "    shuffle=False,\n",
    ")\n",
    "\n",
    "model = torchvision.models.swin_v2_t(\n",
    "    weights=torchvision.models.Swin_V2_T_Weights.IMAGENET1K_V1,\n",
    ")\n",
    "# Remove the classification head: we want embeddings, not ImageNet predictions\n",
    "model.head = nn.Flatten()\n",
    "\n",
    "# If you have a GPU, use it!\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = model.to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "And now we extract the embeddings. This gives us a DataFrame with the embeddings of all the images in the test set, along with their respective class_names."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from easyfsl.utils import predict_embeddings\n",
    "\n",
    "embeddings_df = predict_embeddings(dataloader, model, device=device)\n",
    "\n",
    "print(embeddings_df)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
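  {
   "cell_type": "markdown",
   "source": [
    "Extracting the features once only pays off across sessions if you keep them on disk. Here is a minimal sketch of one way to do it, using pandas' pickle helpers on the DataFrame from the previous cell. The file name is hypothetical, and pickle is just one possible storage format, not something EasyFSL requires."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical file location: adapt it to your project\n",
    "embeddings_path = Path(\"cub_test_embeddings.pickle\")\n",
    "\n",
    "# Pickle preserves the Python objects stored in the DataFrame cells\n",
    "embeddings_df.to_pickle(embeddings_path)\n",
    "\n",
    "# In a later session, reload the DataFrame instead of running the backbone again\n",
    "embeddings_df = pd.read_pickle(embeddings_path)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },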
  {
   "cell_type": "markdown",
   "source": [
    "We now have our embeddings ready to use! We will not use the backbone anymore.\n",
    "\n",
    "## Performing inference on pre-extracted embeddings\n",
    "\n",
    "To deliver the embeddings to our Few-Shot Classifier, we will need an appropriate DataLoader. We will use the `FeaturesDataset` class from EasyFSL. Since we have a DataFrame ready to use, we will use the handy `from_dataframe()` initializer from `FeaturesDataset`, but you can also use `from_dict()` to initialize from a dictionary, or the built-in constructor to initialize it directly from labels and embeddings."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from easyfsl.datasets import FeaturesDataset\n",
    "\n",
    "features_dataset = FeaturesDataset.from_dataframe(embeddings_df)\n",
    "\n",
    "print(features_dataset[0])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Then, like in all other few-shot tutorials, we are going to build a DataLoader that loads batches in the shape of few-shot tasks:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from easyfsl.samplers import TaskSampler\n",
    "\n",
    "task_sampler = TaskSampler(\n",
    "    features_dataset,\n",
    "    n_way=5,\n",
    "    n_shot=5,\n",
    "    n_query=10,\n",
    "    n_tasks=100,\n",
    ")\n",
    "features_loader = DataLoader(\n",
    "    features_dataset,\n",
    "    batch_sampler=task_sampler,\n",
    "    num_workers=num_workers,\n",
    "    pin_memory=True,\n",
    "    collate_fn=task_sampler.episodic_collate_fn,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
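  {
   "cell_type": "markdown",
   "source": [
    "To check what such a task looks like, you can pull one batch from the loader. This is a sketch that assumes `episodic_collate_fn` returns a five-element tuple (support items, support labels, query items, query labels, class ids), as in the other EasyFSL tutorials. The variable names are ours. With the sampler above, we would expect 5 x 5 = 25 support embeddings and 5 x 10 = 50 query embeddings."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Assumption: the collate function yields a five-element tuple per task\n",
    "(\n",
    "    support_features,\n",
    "    support_labels,\n",
    "    query_features,\n",
    "    query_labels,\n",
    "    class_ids,\n",
    ") = next(iter(features_loader))\n",
    "\n",
    "# First dimension: number of items in the task. Second: embedding size.\n",
    "print(support_features.shape, query_features.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },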
  {
   "cell_type": "markdown",
   "source": [
    "We now need to instantiate our Few-Shot Classifier. We will use a Prototypical Network for simplicity, but you can use any other model from EasyFSL.\n",
    "\n",
    "Since we are working directly on features, **we don't need to initialize Prototypical Networks with a backbone**."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from easyfsl.methods import PrototypicalNetworks\n",
    "\n",
    "# Default backbone if we don't specify anything is Identity.\n",
    "# But we specify it anyway for clarity and robustness.\n",
    "few_shot_classifier = PrototypicalNetworks(backbone=nn.Identity())"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
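  {
   "cell_type": "markdown",
   "source": [
    "To see why no backbone is needed, here is a minimal sketch of the computation a Prototypical Network performs once it receives features directly: average the support features of each class into a prototype, then score each query by its distance to every prototype. This is an illustration in plain PyTorch on random toy features, not EasyFSL's implementation. All sizes and variable names are arbitrary."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Toy few-shot task on random features (all sizes are arbitrary)\n",
    "n_way, n_shot, n_query, feature_dim = 5, 5, 10, 16\n",
    "toy_support_features = torch.randn(n_way * n_shot, feature_dim)\n",
    "toy_support_labels = torch.arange(n_way).repeat_interleave(n_shot)\n",
    "toy_query_features = torch.randn(n_way * n_query, feature_dim)\n",
    "\n",
    "# One prototype per class: the mean of that class's support features\n",
    "prototypes = torch.stack(\n",
    "    [\n",
    "        toy_support_features[toy_support_labels == label].mean(dim=0)\n",
    "        for label in range(n_way)\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Score each query by its negative Euclidean distance to each prototype\n",
    "scores = -torch.cdist(toy_query_features, prototypes)\n",
    "predicted_labels = scores.argmax(dim=1)\n",
    "\n",
    "print(scores.shape)  # torch.Size([50, 5])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },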
  {
   "cell_type": "markdown",
   "source": [
    "We can now evaluate our model on the test set, and just enjoy how fast it goes:"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from easyfsl.utils import evaluate\n",
    "\n",
    "accuracy = evaluate(\n",
    "    few_shot_classifier,\n",
    "    features_loader,\n",
    "    device=\"cpu\",\n",
    ")\n",
    "\n",
    "print(f\"Average accuracy : {(100 * accuracy):.2f} %\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "And that is it! Notice that when you're working with pre-extracted embeddings, you can process tasks much faster (65 tasks/s on my MacBook Pro). This should always be your default setting whenever you're working with a method that uses a frozen backbone at test time (that's most of them).\n",
    "\n",
    "## Conclusion\n",
    "Thanks for following this tutorial. If you have any issue, please [raise one](https://github.com/sicara/easy-few-shot-learning/issues), and if EasyFSL is helping you, do not hesitate to [star the repository](https://github.com/sicara/easy-few-shot-learning)."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}